Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: prevent accumulation of SelectFields in PyTorchPredictor #2951

Merged
merged 3 commits into from
Aug 7, 2023

Conversation

Cameronwood611
Copy link
Contributor

Issue #, if available: #2947

Description of changes: Have simple check to not accumulate the same field names, resulting in a RecursionError.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Please tag this pr with at least one of these labels to make our release process faster: bug fix

@abdulfatir @lostella Please review as this relates to changes you (co-)authored

@lostella
Copy link
Contributor

lostella commented Aug 1, 2023

@Cameronwood611 thanks for this; it seems like this fix is breaking the behaviour of models, so I'm not sure it's the right way to go. Is there any way you could share the model you're using for #2947, for debugging purposes? Or even a simpler, smaller dummy model that just runs into the issue. What kind of model are we talking about?

@Cameronwood611
Copy link
Contributor Author

Cameronwood611 commented Aug 1, 2023

@Cameronwood611 thanks for this; it seems like this fix is breaking the behaviour of models, so I'm not sure it's the right way to go. Is there any way you could share the model you're using for #2947, for debugging purposes? Or even a simpler, smaller dummy model that just runs into the issue. What kind of model are we talking about?

Yes! I updated the issue linked but I'll also post the example here --sorry about that. I noticed I had to raise the for loop above 1000 iterations (around the default recursion depth) for it to break; not sure if that's related to different model params, not loading in the model, etc. I'm not sure of the side effects this change has but maybe we could look into a certain filter or adding SelectFields into a set? Would be willing to help with this effort.

import random
import numpy as np
import pandas as pd
from gluonts.torch import PyTorchPredictor
from gluonts.dataset.common import ListDataset, DataEntry
from gluonts.torch import TemporalFusionTransformerEstimator


estimator = TemporalFusionTransformerEstimator(
freq='S',
context_length=10,
prediction_length=5,
trainer_kwargs={'max_epochs': 5}
)
train = ListDataset([{'target': np.array([random.uniform(-1, 1) for _ in range(50)]), 'start': pd.Period('01-01-2023', freq='S')} for _ in range(20)], freq='S')

pred = ListDataset([{'target': np.array([random.uniform(-1, 1) for _ in range(50)]), 'start': pd.Period('01-01-2023', freq='S')} for _ in range(20)], freq='S')
model = estimator.train(train, pred)

x, y, z = np.array([x for x in range(10)]), np.array([y for y in range (10)]), np.array([z for z in range(10)]) #for example
start = pd.Period('01-01-2023', freq='S')
_input = ListDataset([{'target': x, 'start': start},
                          {'target': y, 'start': start},
                          {'target': z, 'start': start}], freq='S')
for i in range(5000): #used to be 1000 for it to break..
    pred = model.predict(_input)
    output = [sig.mean for sig in pred]  #this calls on a generator which causes the recursion error 

self.input_transform += SelectFields(
self.input_names + self.required_fields, allow_missing=True
)
self.provided_fields = True
inference_data_loader = InferenceDataLoader(
dataset,
transform=self.input_transform,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Cameronwood611 the problem is with the += which updates the predictor state every time .predict is invoked. Solution should be to either do here

input_transform = self.input_transform + SelectFields(...)
inference_data_loader = InferenceDataLoader(
    ...,
    transform=input_transform,
    ....,
)

or directly in the constructor

self.input_transform = input_transform + SelectFields(...)

and remove lines 76-78 from here.

@lostella
Copy link
Contributor

lostella commented Aug 3, 2023

@Cameronwood611 thanks for the bug hunting and the fix, if you update it as suggested then it should work fine, and we'll backport it and release it soon

@lostella lostella added bug fix (one of pr required labels) pending v0.13.x backport This contains a fix to be backported to the v0.13.x branch labels Aug 3, 2023
@lostella lostella changed the title Prevent redundant accumulation of SelectFields Fix: prevent accumulation of SelectFields in PyTorchPredictor Aug 7, 2023
@lostella lostella merged commit a944581 into awslabs:dev Aug 7, 2023
17 of 21 checks passed
lostella added a commit to lostella/gluonts that referenced this pull request Aug 7, 2023
…slabs#2951)

* Prevent redundant accumulation of fields

* update fix

---------

Co-authored-by: Cameronwood611 <cwood611@uab.edu>
Co-authored-by: Lorenzo Stella <stellalo@amazon.com>
@lostella lostella mentioned this pull request Aug 7, 2023
@Cameronwood611
Copy link
Contributor Author

@Cameronwood611 thanks for the bug hunting and the fix, if you update it as suggested then it should work fine, and we'll backport it and release it soon

Thank you!

lostella added a commit that referenced this pull request Aug 7, 2023
* Fix JsonLinesFile slicing. (#2925)

* Zebras: Fix index handling of SplitFrame.resize. (#2938)

* Docs: fix missing values use-case in `PandasDataset` docs (#2941)

* Ignore F403 errors in preludes. (#2948)

* Fix: prevent accumulation of `SelectFields` in `PyTorchPredictor` (#2951)

* Prevent redundant accumulation of fields

* update fix

---------

Co-authored-by: Cameronwood611 <cwood611@uab.edu>
Co-authored-by: Lorenzo Stella <stellalo@amazon.com>

* [Docs] fix link to NPTS implementation (#2953)

* Revert "Fix JsonLinesFile slicing. (#2925)"

This reverts commit fa7f9a0.

---------

Co-authored-by: Jasper <schjaspe@amazon.de>
Co-authored-by: cneely33 <cneely33@gmail.com>
Co-authored-by: cameronwood611 <cameron.wood611@gmail.com>
Co-authored-by: Cameronwood611 <cwood611@uab.edu>
@lostella lostella removed the pending v0.13.x backport This contains a fix to be backported to the v0.13.x branch label Aug 22, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug fix (one of pr required labels)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants